{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0db0cb1f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PreTrainedTokenizerFast(name_or_path='bert-base-cased', vocab_size=28996, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "#加载分词工具\n",
    "tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')\n",
    "\n",
    "tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "02b6b4b5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/glue_sst2/train/cache-8cfcfd522c905812.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/glue_sst2/train/cache-b3e205139be224a8.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/glue_sst2/train/cache-890424ad8ed86a19.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/glue_sst2/train/cache-de855cf12202060d.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/glue_sst2/validation/cache-1de25e78c7da64ff.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/glue_sst2/validation/cache-620f9bf015596baf.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/glue_sst2/validation/cache-0a0b01f98a68ed20.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/glue_sst2/validation/cache-f7f25a94b8aafa52.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/glue_sst2/test/cache-524ecec34978f89b.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/glue_sst2/test/cache-86844412c2345cce.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/glue_sst2/test/cache-632228277b828dc1.arrow\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " "
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading cached processed dataset at data/glue_sst2/test/cache-fd22071d68ee727d.arrow\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Dataset({\n",
       "    features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
       "    num_rows: 1000\n",
       "})"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "from datasets import load_from_disk\n",
    "\n",
    "#加载数据集\n",
    "#从网络加载\n",
    "#datasets = load_dataset(path='glue', name='sst2')\n",
    "\n",
    "#从本地磁盘加载数据\n",
    "datasets = load_from_disk('./data/glue_sst2')\n",
    "\n",
    "\n",
    "#分词\n",
    "def f(data):\n",
    "    return tokenizer(\n",
    "        data['sentence'],\n",
    "        padding='max_length',\n",
    "        truncation=True,\n",
    "        max_length=30,\n",
    "    )\n",
    "\n",
    "\n",
    "datasets = datasets.map(f, batched=True, batch_size=1000, num_proc=4)\n",
    "\n",
    "#取数据子集，否则数据太多跑不动\n",
    "dataset_train = datasets['train'].shuffle().select(range(1000))\n",
    "dataset_test = datasets['validation'].shuffle().select(range(200))\n",
    "\n",
    "del datasets\n",
    "\n",
    "dataset_train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "74b91f77",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "10831.181\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoModelForSequenceClassification\n",
    "\n",
    "#加载模型\n",
    "model = AutoModelForSequenceClassification.from_pretrained('bert-base-cased',\n",
    "                                                           num_labels=2)\n",
    "\n",
    "print(sum([i.nelement() for i in model.parameters()]) / 10000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "876c0287",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'accuracy': 1.0}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "from datasets import load_metric\n",
    "from transformers.trainer_utils import EvalPrediction\n",
    "\n",
    "#加载评价函数\n",
    "#有时会因为网络问题卡主,反复尝试会成功的\n",
    "metric = load_metric('accuracy')\n",
    "\n",
    "\n",
    "#定义评价函数\n",
    "def compute_metrics(eval_pred):\n",
    "    logits, labels = eval_pred\n",
    "    logits = logits.argmax(axis=1)\n",
    "    return metric.compute(predictions=logits, references=labels)\n",
    "\n",
    "\n",
    "#模拟测试输出\n",
    "eval_pred = EvalPrediction(\n",
    "    predictions=np.array([[0, 1], [2, 3], [4, 5], [6, 7]]),\n",
    "    label_ids=np.array([1, 1, 1, 1]),\n",
    ")\n",
    "\n",
    "compute_metrics(eval_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f2c59b49",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: sentence, idx. If sentence, idx are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 200\n",
      "  Batch size = 32\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='14' max='7' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [7/7 01:35]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'eval_loss': 0.7613461017608643,\n",
       " 'eval_accuracy': 0.505,\n",
       " 'eval_runtime': 4.5246,\n",
       " 'eval_samples_per_second': 44.202,\n",
       " 'eval_steps_per_second': 1.547}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import TrainingArguments, Trainer\n",
    "\n",
    "#初始化训练参数\n",
    "args = TrainingArguments(output_dir='./output_dir',\n",
    "                         evaluation_strategy='epoch',\n",
    "                         no_cuda=True)\n",
    "args.num_train_epochs = 1\n",
    "args.learning_rate = 1e-4\n",
    "args.weight_decay = 1e-2\n",
    "args.per_device_eval_batch_size = 32\n",
    "args.per_device_train_batch_size = 16\n",
    "\n",
    "#初始化训练器\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=args,\n",
    "    train_dataset=dataset_train,\n",
    "    eval_dataset=dataset_test,\n",
    "    compute_metrics=compute_metrics,\n",
    ")\n",
    "\n",
    "#评价模型\n",
    "trainer.evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "104de6e9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The following columns in the training set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: sentence, idx. If sentence, idx are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.\n",
      "/root/anaconda3/envs/pt36/lib/python3.6/site-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
      "  FutureWarning,\n",
      "***** Running training *****\n",
      "  Num examples = 1000\n",
      "  Num Epochs = 1\n",
      "  Instantaneous batch size per device = 16\n",
      "  Total train batch size (w. parallel, distributed & accumulation) = 16\n",
      "  Gradient Accumulation steps = 1\n",
      "  Total optimization steps = 63\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='63' max='63' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [63/63 01:30, Epoch 1/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Epoch</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "      <th>Accuracy</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>1</td>\n",
       "      <td>No log</td>\n",
       "      <td>0.462459</td>\n",
       "      <td>0.790000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: sentence, idx. If sentence, idx are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 200\n",
      "  Batch size = 32\n",
      "\n",
      "\n",
      "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=63, training_loss=0.5899859837123326, metrics={'train_runtime': 92.1116, 'train_samples_per_second': 10.856, 'train_steps_per_second': 0.684, 'total_flos': 15416663400000.0, 'train_loss': 0.5899859837123326, 'epoch': 1.0})"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#训练\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6630698b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: sentence, idx. If sentence, idx are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.\n",
      "***** Running Evaluation *****\n",
      "  Num examples = 200\n",
      "  Batch size = 32\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='7' max='7' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [7/7 00:03]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'eval_loss': 0.46245864033699036,\n",
       " 'eval_accuracy': 0.79,\n",
       " 'eval_runtime': 4.2879,\n",
       " 'eval_samples_per_second': 46.643,\n",
       " 'eval_steps_per_second': 1.633,\n",
       " 'epoch': 1.0}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#评价模型\n",
    "trainer.evaluate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "40bddaa5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Saving model checkpoint to ./output_dir\n",
      "Configuration saved in ./output_dir/config.json\n",
      "Model weights saved in ./output_dir/pytorch_model.bin\n"
     ]
    }
   ],
   "source": [
    "#保存模型\n",
    "trainer.save_model(output_dir='./output_dir')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "2ed86d9d",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([1, 0, 1, 1]),\n",
       " tensor([[  101,  1177,  1277,   175,  7409,  4759,  5531,   117,  1216, 10509,\n",
       "           4133,   117,  1177,  1376,  2523,   119,   102,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0],\n",
       "         [  101,  1175,  1132,  4928,  7996,  1992,  1536,  1111,   188,  2522,\n",
       "           1358,  1103,  7010, 17757,  1106, 11231,  1194,   119,   102,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0],\n",
       "         [  101,  1209,  1821,  5613,  1105,  5250, 14638, 16889, 25576,  6323,\n",
       "           1107, 13858,  9165,   119,   102,     0,     0,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0],\n",
       "         [  101,  1141,  1104,  1103,  1167,  9998,  1482,   112,   188,  5558,\n",
       "           1106,  1855, 13090,  1142,  1214,   119,   102,     0,     0,     0,\n",
       "              0,     0,     0,     0,     0,     0,     0,     0,     0,     0]]),\n",
       " tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0],\n",
       "         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0]]),\n",
       " tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0],\n",
       "         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0],\n",
       "         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0],\n",
       "         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,\n",
       "          0, 0, 0, 0, 0, 0]]))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def collate_fn(data):\n",
    "    label = [i['label'] for i in data]\n",
    "    input_ids = [i['input_ids'] for i in data]\n",
    "    token_type_ids = [i['token_type_ids'] for i in data]\n",
    "    attention_mask = [i['attention_mask'] for i in data]\n",
    "\n",
    "    label = torch.LongTensor(label)\n",
    "    input_ids = torch.LongTensor(input_ids)\n",
    "    token_type_ids = torch.LongTensor(token_type_ids)\n",
    "    attention_mask = torch.LongTensor(attention_mask)\n",
    "\n",
    "    return label, input_ids, token_type_ids, attention_mask\n",
    "\n",
    "\n",
    "#数据加载器\n",
    "loader_test = torch.utils.data.DataLoader(dataset=dataset_test,\n",
    "                                          batch_size=4,\n",
    "                                          collate_fn=collate_fn,\n",
    "                                          shuffle=True,\n",
    "                                          drop_last=True)\n",
    "\n",
    "for i, (label, input_ids, token_type_ids,\n",
    "        attention_mask) in enumerate(loader_test):\n",
    "    break\n",
    "\n",
    "label, input_ids, token_type_ids, attention_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "19fd71e0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.75"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "#测试\n",
    "def test():\n",
    "    #加载参数\n",
    "    model.load_state_dict(torch.load('./output_dir/pytorch_model.bin'))\n",
    "\n",
    "    model.eval()\n",
    "\n",
    "    #运算\n",
    "    out = model(input_ids=input_ids,\n",
    "                token_type_ids=token_type_ids,\n",
    "                attention_mask=attention_mask)\n",
    "\n",
    "    #[4, 2] -> [4]\n",
    "    out = out['logits'].argmax(dim=1)\n",
    "\n",
    "    correct = (out == label).sum().item()\n",
    "\n",
    "    return correct / len(label)\n",
    "\n",
    "\n",
    "test()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
